
import numpy as np
from numpy.linalg import norm
from scipy.stats import skew
import matplotlib.pyplot as plt


# Placeholder locations -- point these at the real files before running.
DATA_PATH = "Path to extracted faetures from dino train_dino.npz"
SAVE_PATH = "Path to save pseudolabeled results, pseudolabel_results.npz"

# NOTE(review): allow_pickle=True permits unpickling object arrays (the stored
# "paths" presumably need it). Unpickling can execute arbitrary code, so only
# load .npz files you produced yourself or otherwise trust.
# Using np.load as a context manager closes the underlying zip archive once
# the arrays below have been read into memory.
with np.load(DATA_PATH, allow_pickle=True) as data:
    X_full_raw = data["features"].astype(np.float32)
    y_full = data["labels"].astype(int)
    image_paths = data["paths"]

print("Loaded dataset:", X_full_raw.shape)

# L2-normalise every feature row. An all-zero row would divide by zero and
# become NaN, which would then poison any class prototype built from it;
# flooring the norm leaves such rows as zero vectors instead.
row_norms = np.linalg.norm(X_full_raw, axis=1, keepdims=True)
X_full_norm = X_full_raw / np.maximum(row_norms, 1e-12)

# Reproducible random split: 30% of the samples keep their labels, the other
# 70% are treated as unlabeled. The legacy global seed is kept so the split is
# identical to earlier runs of this script.
np.random.seed(42)

N = len(X_full_norm)
indices = np.random.permutation(N)

n_labeled = int(0.3 * N)

labeled_idx = indices[:n_labeled]
unlabeled_idx = indices[n_labeled:]

# Features: raw copies are what get written to SAVE_PATH at the end; the
# normalized copies drive the prototype and distance computations.
X_l_raw = X_full_raw[labeled_idx]
X_u_raw = X_full_raw[unlabeled_idx]

X_l = X_full_norm[labeled_idx]
X_u = X_full_norm[unlabeled_idx]

# Labels: y_u_true is the ground truth of the "unlabeled" split.
y_l = y_full[labeled_idx]
y_u_true = y_full[unlabeled_idx]

paths_l = image_paths[labeled_idx]
paths_u = image_paths[unlabeled_idx]

# The class set comes from the labeled split only; a class that is absent
# there can never be assigned as a pseudolabel.
classes = np.unique(y_l)

print("Initial labeled:", len(X_l))
print("Initial unlabeled:", len(X_u))

def geometric_median(X, eps=1e-5, max_iter=1000):
    """Return the geometric (spatial L1) median of the rows of ``X``.

    Runs Weiszfeld's fixed-point iteration, starting from the arithmetic
    mean. It includes the Vardi-Zhang correction for iterates that land
    exactly on sample points. The geometric median minimises the sum of
    Euclidean distances to the rows, so a few outlying rows move it far less
    than they move the mean.

    Args:
        X: 2-D array of shape (n_samples, n_features) with at least one row.
        eps: Stop once two consecutive iterates are closer than this
            (Euclidean distance).
        max_iter: Upper bound on the number of iterations. It guarantees
            termination when the iterates never settle within ``eps``.

    Returns:
        1-D array of length n_features. If ``max_iter`` is exhausted, the
        last iterate is returned and a ``RuntimeWarning`` is issued.

    Raises:
        ValueError: If ``X`` is not 2-D or has no rows.
    """
    import warnings

    X = np.asarray(X)
    if X.ndim != 2 or X.shape[0] == 0:
        raise ValueError(
            f"geometric_median expects a non-empty 2-D array, got shape {X.shape}"
        )

    n_samples = X.shape[0]
    y = X.mean(axis=0)

    for _ in range(max_iter):
        D = np.linalg.norm(X - y, axis=1)
        nonzeros = D != 0
        n_zeros = n_samples - np.count_nonzero(nonzeros)

        # Every sample coincides with the current estimate: it is the median.
        if n_zeros == n_samples:
            return y

        # Classic Weiszfeld step over the samples not sitting on y.
        D_inv = 1 / D[nonzeros]
        D_inv_sum = D_inv.sum()
        T = (D_inv[:, None] * X[nonzeros]).sum(axis=0) / D_inv_sum

        if n_zeros == 0:
            y_next = T
        else:
            # Vardi-Zhang correction. y sits on n_zeros sample points, which
            # plain Weiszfeld ignores, so it would jump away even when y is
            # already optimal. R is the resultant of the unit vectors from y
            # toward the other samples. If |R| <= n_zeros, y satisfies the
            # optimality condition and stays put; otherwise move only
            # part-way toward T.
            R = (T - y) * D_inv_sum
            r = np.linalg.norm(R)
            r_inv = 0.0 if r == 0 else n_zeros / r
            y_next = max(0.0, 1.0 - r_inv) * T + min(1.0, r_inv) * y

        if np.linalg.norm(y - y_next) < eps:
            return y_next

        y = y_next

    warnings.warn(
        f"geometric_median did not converge within {max_iter} iterations",
        RuntimeWarning,
        stacklevel=2,
    )
    return y

# One prototype per labeled class: the geometric median of that class's
# normalized features. Alongside it, a spread scale sigma equal to the median
# distance of the class members from that prototype. The small 1e-8 keeps
# later divisions by sigma finite when every member sits on the prototype.
prototypes = {
    cls: geometric_median(X_l[y_l == cls])
    for cls in classes
}

sigma = {
    cls: np.median(np.linalg.norm(X_l[y_l == cls] - prototypes[cls], axis=1)) + 1e-8
    for cls in classes
}

# Score every unlabeled sample against every class prototype. Distances are
# divided by the class spread sigma[c], so they are measured in units of
# "typical distance from that class's prototype".
if len(classes) < 2:
    # The margin below needs a runner-up class; with a single class there is
    # none (the per-sample argsort(...)[1] used to fail with IndexError).
    raise ValueError(
        f"Pseudolabel margins need at least 2 labeled classes, got {len(classes)}"
    )

# scaled_dists[i, j] is the sigma-scaled Euclidean distance of X_u[i] to the
# prototype of classes[j]. It is filled one class column at a time, so peak
# memory stays at one (n_unlabeled, n_features) temporary instead of a full
# 3-D broadcast.
scaled_dists = np.empty((len(X_u), len(classes)), dtype=np.float64)
for j, c in enumerate(classes):
    scaled_dists[:, j] = np.linalg.norm(X_u - prototypes[c], axis=1) / sigma[c]

# The nearest prototype determines the pseudolabel.
best_idx = np.argmin(scaled_dists, axis=1)
predicted_labels = classes[best_idx].astype(int)

# Nearest (d1) and runner-up (d2) distance per sample. Partitioning around
# index 1 puts the minimum in column 0 and the second smallest in column 1.
two_nearest = np.partition(scaled_dists, 1, axis=1)
d1 = two_nearest[:, 0]
d2 = two_nearest[:, 1]

# Relative margin in [0, 1): 0 when the two nearest classes tie, and close
# to 1 when a sample sits almost on its prototype.
confidence = (d2 - d1) / (d1 + d2 + 1e-12)

# Min-max rescale across the whole unlabeled set so the per-class thresholds
# computed afterwards all work on a common [0, 1] scale.
conf_min = confidence.min()
conf_max = confidence.max()

confidence = (confidence - conf_min) / (conf_max - conf_min + 1e-12)

print("\nConfidence range:", confidence.min(), confidence.max())

# Histogram of the rescaled confidence over all unlabeled samples.
fig, ax = plt.subplots(figsize=(6, 4))
ax.hist(confidence, bins=40)
ax.set_title("Global Confidence Distribution")
ax.set_xlabel("Confidence")
ax.set_ylabel("Frequency")
plt.show()

# Per-class adaptive acceptance. For each predicted class with at least 10
# candidates, a sample is accepted when its confidence reaches
#     tau = mean + (1 + max(skew, 0)) * std,   clipped to [0, 1].
# A right-skewed score distribution therefore raises the bar further.
# Classes with fewer than 10 candidates contribute no pseudolabels at all.
accepted_mask = np.zeros(len(X_u), dtype=bool)

print("\nClass-wise Thresholds")

for c in classes:
    in_class = predicted_labels == c
    class_scores = confidence[in_class]

    if class_scores.size < 10:
        continue

    mean_score = np.mean(class_scores)
    std_score = np.std(class_scores)
    skewness = skew(class_scores)

    # If skew() returns NaN (e.g. constant scores), NaN > 0 is False, so the
    # boost is 0 and tau reduces to mean + std.
    skew_boost = skewness if skewness > 0 else 0
    threshold = np.clip(mean_score + (1 + skew_boost) * std_score, 0, 1)

    chosen = in_class & (confidence >= threshold)
    accepted_mask |= chosen

    print("\nClass", c)
    print("Mean:", round(mean_score, 4))
    print("Std :", round(std_score, 4))
    print("Skew:", round(skewness, 4))
    print("Threshold:", round(threshold, 4))
    print("Selected samples:", chosen.sum())

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(class_scores, bins=40)
    ax.axvline(threshold, linestyle="--", label=f"Threshold={threshold:.3f}")
    ax.set_title(f"Confidence Distribution - Class {c}")
    ax.set_xlabel("Confidence")
    ax.set_ylabel("Frequency")
    ax.legend()
    plt.show()

# Check the accepted pseudolabels against the ground truth of the unlabeled
# split (y_u_true); in this script those labels are used only for this
# evaluation.
num_selected = accepted_mask.sum()
is_correct = predicted_labels[accepted_mask] == y_u_true[accepted_mask]
correct = is_correct.sum()

if num_selected > 0:
    precision = correct / num_selected
else:
    precision = 0  # nothing accepted: report 0 instead of dividing by zero

print("\nSelected pseudolabeled samples:", num_selected)
print("Pseudolabel precision:", round(precision, 4))


# Gather the accepted unlabeled samples (raw, unnormalized features) together
# with their predicted labels and image paths.
pseudo_X_raw, pseudo_y, pseudo_paths = (
    X_u_raw[accepted_mask],
    predicted_labels[accepted_mask],
    paths_u[accepted_mask],
)

# Append them after the originally labeled samples.
X_combined_raw = np.concatenate((X_l_raw, pseudo_X_raw), axis=0)
y_combined = np.concatenate((y_l, pseudo_y))
combined_paths = np.concatenate((paths_l, pseudo_paths))

print("\nFinal training size:", len(X_combined_raw))

# Persist three views of the data plus the measured precision. The key names
# below form the schema of the saved .npz archive.
results = {
    # original labeled
    "original_labeled_features": X_l_raw,
    "original_labeled_labels": y_l,
    "original_labeled_paths": paths_l,
    # pseudolabeled
    "pseudolabeled_features": pseudo_X_raw,
    "pseudolabeled_labels": pseudo_y,
    "pseudolabeled_paths": pseudo_paths,
    # combined dataset
    "combined_features": X_combined_raw,
    "combined_labels": y_combined,
    "combined_paths": combined_paths,
    "pseudolabel_precision": precision,
}
np.savez(SAVE_PATH, **results)

print("Saved results to:", SAVE_PATH)
